%% Source code for paper "The Sample Complexity of Online RL" 
% 
% (the code generates plots presented in the article)
%
%
%% initialization and setup

% tabula rasa: clear the workspace and close all open figures
clear all;
close all;

% fix the random seed so that the sampled candidate models (and hence all
% plots) are reproducible
rng(27111989);

n=20;   % number of states
nu=5;   % number of inputs
no=100; % number of experts (candidate models)

% quadratic cost weights on the state (Q) and the input (R)
Q=eye(n);
R=eye(nu);

% set up system matrices as detailed in corresponding appendix section ----
% (nu decoupled chains, each made of nmod=n/nu leaky integrators)
% the block-diagonal construction below requires n to be a multiple of nu
assert(mod(n,nu)==0,'n (=%d) must be a multiple of nu (=%d)',n,nu);
nmod=n/nu;
B0=zeros(nmod,1);
B0(end)=1;           % each input enters the last state of its chain
Bt=kron(eye(nu),B0); % true dynamics

tmp=eye(nmod+1);
% ones on the super-diagonal (state i is driven by state i+1) plus a leak
% factor of 0.8 on the diagonal
A0=tmp(2:end,1:end-1)+diag(0.8*ones(nmod,1));  
At=kron(eye(nu),A0); % true dynamics
% -------------------------------------------------------------------------


% generate candidate models -----------------------------------------------
abserror=0.1;   % absolute error
relerror=0.2;   % relative error

% element-wise box around the true parameters; the upper bound stays above
% the lower bound as long as At and Bt have no entries below
% -abserror/relerror, which holds here since all their entries are >= 0
AUbound=At*(1+relerror)+abserror;
ALbound=At*(1-relerror)-abserror;
BUbound=Bt*(1+relerror)+abserror;
BLbound=Bt*(1-relerror)-abserror;
thetaUbound=[AUbound(:);BUbound(:)];    % upper bound on theta
thetaLbound=[ALbound(:);BLbound(:)];    % lower bound on theta
% -------------------------------------------------------------------------

% compute LQR corresponding to true dynamics (benchmark)
% (dlqr's gain is for u=-K*x; the sign is flipped so that u=Kt*x)
Kt=-dlqr(At,Bt,Q,R);


% generate dictionary of candidate models ---------------------------------
% cell arrays are preallocated (1-by-no) instead of being regrown on every
% iteration
Atot=cell(1,no);   % system matrices
Btot=cell(1,no);   % system matrices
Ktot=cell(1,no);   % feedback gains (sign flipped, u=K*x)
Ptot=cell(1,no);   % Riccati solutions S from dlqr (steady-state performance)
%
for k=1:no
    if k==1
        % first model corresponds to true dynamics (WLOG)
        Atot{k}=At;
        Btot{k}=Bt;
    else
        % sample a point on the segment between thetaLbound and thetaUbound
        % NOTE(review): a single scalar lambda is drawn per model, so all
        % parameters move jointly along that segment rather than being
        % sampled uniformly over the whole box -- confirm this is intended
        lambda=rand;           
        theta=thetaLbound*(1-lambda)+thetaUbound*lambda;

        % extract corresponding matrices
        [A0,B0]=getMatrices(theta,n,nu);

        % add matrices to dictionary
        Atot{k}=A0;
        Btot{k}=B0;
    end
    % compute feedback policy corresponding to each candidate model
    [Ktmp,Stmp,~]=dlqr(Atot{k},Btot{k},Q,R);
    Ktot{k}=-Ktmp;
    Ptot{k}=Stmp;
end


%% run experiments

T=100; % time horizon

% simulate Alg.~1 (operates on the dictionary of candidate models)
[cost,cost_opt,traj,est]=runExperimentExperts(1,Atot,Btot,Ktot,Q,R,T);

% simulate Alg.~3 (operates on the parameter box and the benchmark gain Kt)
[cost2,cost_opt2,traj2,est2]=runExperiment(At,Bt,thetaUbound,thetaLbound,Q,R,Kt,T);



%% extract and generate plots ---------------------------------------------
%
%% 1. plots for inspection
figure
subplot(3,1,1)
plot(est.pk')
xlabel('iterations')
ylabel('prob of different options')
subplot(3,1,2)
semilogy(abs(traj.xtraj)')
xlabel('iterations')
ylabel('state trajectory')
subplot(3,1,3)
semilogy(abs(traj.utraj)')
xlabel('iterations')
ylabel('input trajectory')

% compute accumulated costs as running sums down the columns of cost';
% equivalent to tril(ones(T))*cost' but O(T) instead of a T-by-T product
cost_acc=cumsum(cost',1);
cost_acc_opt=cumsum(cost_opt',1);

figure
subplot(2,1,1)
plot(cost_acc)
hold on
plot(cost_acc_opt)
% the curves are accumulated costs; the regret is their difference
ylabel('accumulated cost')
xlabel('iterations')
legend('Alg. 1','optimal')

subplot(2,1,2)
plot(traj.sigma_uk)
ylabel('sigma_u')
xlabel('iterations')

% print out regret (T is an integer, hence %d)
fprintf('regret: %f; t: %d\n',cost_acc(end)-cost_acc_opt(end),T)

% extract parameters of the candidate model indexed by est.ik at each step
% (assumes est.ik holds dictionary indices, one per time step -- confirm)
theta=computeTheta(est.ik,Atot,Btot);
thetat=[At(:);Bt(:)];

% compute parameter errors (Euclidean norm of each column w.r.t. thetat)
errnorm1=sqrt(sum(abs(theta-thetat*ones(1,T)).^2));
errnorm2=sqrt(sum(abs(est2.theta-thetat*ones(1,T)).^2));

%% 2. plots for article

% render text created from here on verbatim (no TeX/LaTeX interpretation),
% presumably so that the '$...$' labels reach matlab2tikz unchanged
% (parseStrings=false below); note this changes the root default for the
% rest of the MATLAB session
set(groot,'defaulttextinterpreter','none') 

% parameter error Alg. 1 (two norm)
figure
plot(errnorm1,'k')
ylabel('$|\theta_k-\theta^*|$')
xlabel('$k$')

% save with matlab2tikz (requires matlab2tikz)
% matlab2tikz('trajectory_theta_S1.tikz','height', '\figureheight', 'width', '\figurewidth','parseStrings',false);


%%

% parameter error Alg. 3 (two norm)
figure
plot(errnorm2,'k')
ylabel('$|\theta_k-\theta^*|$')
xlabel('$k$')

% save with matlab2tikz (requires matlab2tikz)
% matlab2tikz('trajectory_theta_S2.tikz','height', '\figureheight', 'width', '\figurewidth','parseStrings',false);


%%

% l2-norm of state trajectory Alg. 1 compared to benchmark
figure
plot(sqrt(sum(traj.xtraj.^2)),'r')
hold on
plot(sqrt(sum(traj.xtraj_opt.^2)),'k--')
ylabel('$|x_k|$')
xlabel('$k$')
legend('Alg.~1', 'Opt')

% save with matlab2tikz (requires matlab2tikz)
% matlab2tikz('trajectory_x_S1.tikz','height', '\figureheight', 'width', '\figurewidth','parseStrings',false);

%%

% l2-norm of state trajectory of Alg. 3 (second experiment, traj2)
% compared to benchmark
% NOTE(review): the benchmark curve comes from the first experiment
% (traj.xtraj_opt), not from traj2 -- confirm both experiments share the
% same optimal trajectory; the legend says 'Alg.~2' while the comments
% refer to Alg. 3 -- confirm against the paper
figure
semilogy(sqrt(sum(traj2.xtraj.^2)),'r')
hold on
semilogy(sqrt(sum(traj.xtraj_opt.^2)),'k--')
ylabel('$|x_k|$')
xlabel('$k$')
legend('Alg.~2', 'Opt')

% save with matlab2tikz (requires matlab2tikz)
% matlab2tikz('trajectory_x_S2.tikz','height', '\figureheight', 'width', '\figurewidth','parseStrings',false);





%% function for extracting system matrices (A,B) from parameter vector
%
%   theta is read in column-major order as [A(:);B(:)]: the first nx*nx
%   entries fill the nx-by-nx matrix A, the remaining ones fill the
%   nx-by-nu matrix B (reshape errors if the count does not match).
%
%   input:     theta    parameter vector
%              nx,nu    state/input dimension
%   output:    (A,B)    system matrices
function [A,B]=getMatrices(theta,nx,nu)
    nA=nx*nx;                              % number of entries that form A
    A=reshape(theta(1:nA),[nx,nx]);
    B=reshape(theta(nA+1:end),[nx,nu]);
end

%% function to extract system parameters given candidate model
%
% For each entry of ik, the selected dictionary model is flattened into one
% column [A(:);B(:)] (the same layout that getMatrices unpacks).
%
% input:    ik          candidate model indices into the dictionary
%           (Atot,Btot) dictionary of models
%
% output:   theta       system parameters (A,B), one column per entry of ik
function [theta]=computeTheta(ik,Atot,Btot)
    nx=size(Atot{1},1);
    nin=size(Btot{1},2);
    nsteps=length(ik);

    theta=zeros(nx*nx+nx*nin,nsteps);
    for col=1:nsteps
        sel=ik(col);
        Asel=Atot{sel};
        Bsel=Btot{sel};
        theta(:,col)=[Asel(:);Bsel(:)];
    end
end
